﻿// -----------------------------------------------------------------------
// <copyright file="Sphere.cs" company="">
// TODO: Update copyright text.
// </copyright>
// -----------------------------------------------------------------------

namespace Assignment_3.Scene.Basic_Shapes
{
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;
    using Assignment_3.GenericStructures;
    using Assignment_3.Raytracer;

    /// <summary>
    /// A sphere described by a center point and a radius. Provides ray
    /// intersection, surface normal, reflection and refraction calculations.
    /// </summary>
    public class Sphere : Shape
    {
        /// <summary>
        /// Largest allowed difference between a point's distance to the center
        /// and the radius for the point to count as lying on the surface.
        /// </summary>
        private const double SurfaceTolerance = 0.001;

        private readonly Vector center_;
        private readonly Double radius_;

        /// <summary>
        /// Creates a sphere from its center and radius.
        /// </summary>
        /// <param name="center">Center point of the sphere.</param>
        /// <param name="radius">Radius of the sphere.</param>
        public Sphere(Vector center, Double radius)
        {
            this.center_ = center;
            this.radius_ = radius;
        }

        /// <summary>
        /// Finds the point where the ray first meets the sphere's surface.
        /// </summary>
        /// <param name="r">Ray to test. The math below assumes <c>r.direction</c> is unit length (TODO confirm against callers).</param>
        /// <returns>
        /// The intersection point, or <c>null</c> when the ray misses the sphere
        /// or the intersection lies behind the ray origin.
        /// </returns>
        public override Vector getFirstIntersection(Ray r)
        {
            // Vector from the ray origin to the center of the sphere.
            Vector toCenter = this.center_ - r.origin;

            // Distance along the ray to the point closest to the sphere's center.
            double v = Vector.DotProduct(toCenter, r.direction);

            // radius^2 minus the squared distance from the center to the ray's line.
            double discriminant = (radius_ * radius_) - (Vector.DotProduct(toCenter, toCenter) - (v * v));
            if (discriminant < 0)
            {
                return null;
            }

            double d = Math.Sqrt(discriminant);

            // From outside, the nearer root (v - d) is the entry point; from
            // inside (or on the surface), the farther root (v + d) is the exit point.
            bool originOutside = (r.origin - this.center_).Length() > radius_;
            double t = originOutside ? v - d : v + d;

            // A negative ray parameter means the hit lies behind the origin.
            // Testing t directly avoids dividing by direction components that may be zero.
            if (t < 0)
            {
                return null;
            }

            return r.origin + (t * r.direction);
        }

        /// <summary>
        /// Checks whether a point lies on the sphere's surface, within
        /// <see cref="SurfaceTolerance"/>.
        /// </summary>
        /// <param name="intersectionPoint">Point to test.</param>
        /// <returns><c>true</c> when the point's distance to the center differs from the radius by less than the tolerance.</returns>
        protected bool PointIntersects(Vector intersectionPoint)
        {
            double distanceToCenter = (intersectionPoint - this.center_).Length();
            return Math.Abs(distanceToCenter - radius_) < SurfaceTolerance;
        }

        /// <summary>
        /// Builds the mirror-reflected ray at a point on the sphere.
        /// </summary>
        /// <param name="original">Incoming ray.</param>
        /// <param name="intersectionPoint">Point where the incoming ray hits the sphere.</param>
        /// <returns>
        /// A ray starting at <paramref name="intersectionPoint"/> with a normalized
        /// reflected direction, or <c>null</c> when the point is not on the surface
        /// or the incoming ray does not oppose the outward normal.
        /// </returns>
        public override Ray getReflectedRay(Ray original, Vector intersectionPoint)
        {
            if (!PointIntersects(intersectionPoint))
            {
                return null;
            }

            Vector normal = calculateNormal(intersectionPoint, original);
            if (original.direction.DotProduct(normal) <= 0)
            {
                // R = I + 2 * N * (-N . I)
                double c1 = -Vector.DotProduct(normal, original.direction);
                Ray reflectedRay = new Ray(intersectionPoint, original.direction + (2 * normal * c1));
                reflectedRay.direction /= reflectedRay.direction.Length();
                return reflectedRay;
            }

            return null;
        }

        /// <summary>
        /// Builds the refracted (transmitted) ray at a point on the sphere using Snell's law.
        /// </summary>
        /// <param name="original">Incoming ray.</param>
        /// <param name="intersectionPoint">Point where the incoming ray hits the sphere.</param>
        /// <param name="indexOfRefractionOutside">Index of refraction of the medium outside the sphere.</param>
        /// <param name="indexOfRefractionInside">Not used by this implementation; the sphere's own <c>properties.indexOfRefraction</c> is used for the inside medium.</param>
        /// <returns>
        /// A ray starting at <paramref name="intersectionPoint"/> with a normalized
        /// refracted direction, or <c>null</c> when
        /// <c>properties.transparencyCoefficient</c> is not below 1, the point is not
        /// on the surface, or total internal reflection occurs.
        /// </returns>
        public override Ray getRefractedRay(Ray original, Vector intersectionPoint, double indexOfRefractionOutside = 1.00, double indexOfRefractionInside=1.00)
        {
            if (!(properties.transparencyCoefficient < 1))
            {
                return null;
            }

            if (!PointIntersects(intersectionPoint))
            {
                return null;
            }

            // Decide whether the ray travels from outside to inside or the reverse.
            Vector normal = calculateNormal(intersectionPoint, original);
            double n;
            if (original.direction.DotProduct(normal) < 0)
            {
                // Entering the sphere.
                n = indexOfRefractionOutside / this.properties.indexOfRefraction;
            }
            else
            {
                // Back-facing hit (leaving the sphere): swap the media and flip
                // the normal so that it opposes the incoming direction.
                n = this.properties.indexOfRefraction / indexOfRefractionOutside;
                normal = new Vector(0, 0, 0) - normal;
            }

            double c1 = -Vector.DotProduct(normal, original.direction);

            // Squared cosine of the transmitted angle. A negative value means
            // total internal reflection: no transmitted ray exists, and taking
            // the square root would yield NaN and a NaN ray direction.
            double cosTransmittedSquared = 1 - (n * n) * (1 - (c1 * c1));
            if (cosTransmittedSquared < 0)
            {
                return null;
            }

            double c2 = Math.Sqrt(cosTransmittedSquared);

            Ray refractedRay = new Ray(intersectionPoint, (n * original.direction) + (n * c1 - c2) * normal);
            refractedRay.direction /= refractedRay.direction.Length();
            return refractedRay;
        }

        /// <summary>
        /// Computes the unit vector pointing from the sphere's center towards the given point.
        /// </summary>
        /// <param name="intersectionPoint">Point at which to compute the normal.</param>
        /// <param name="r">Not used by this implementation.</param>
        /// <returns>
        /// The normalized vector from the center to <paramref name="intersectionPoint"/>.
        /// If the point coincides with the center the length is zero and the
        /// components are the result of a division by zero.
        /// </returns>
        public override Vector calculateNormal(Vector intersectionPoint, Ray r)
        {
            Vector normal = intersectionPoint - center_;

            double length = Math.Sqrt((normal.x * normal.x) + (normal.y * normal.y) + (normal.z * normal.z));

            normal.x /= length;
            normal.y /= length;
            normal.z /= length;

            return normal;
        }
    }
}
